#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
@Software: PyCharm
@Author:  zhaojianghua
@Email: zhaojianghua1990@qq.com
@Homepage: None
"""

import random
import copy
import collections
import bisect
import json

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


def swish(x):
    """Swish/SiLU activation: ``x * sigmoid(x)``.

    Mathematically identical to ``x / (1 + exp(-x))``, but computed via
    ``torch.sigmoid`` so very negative inputs do not overflow ``exp``.
    """
    return x * torch.sigmoid(x)


def one_hot(labels, dim_size, dtype=torch.float32):
    """Convert integer class labels to a one-hot matrix.

    Args:
        labels: integer class indices, shape ``[N]`` or ``[N, 1]``;
            accepts ``np.ndarray`` or ``torch.Tensor``.
        dim_size: number of classes (width of the result).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[N, dim_size]`` with a single 1 per row.
    """
    if isinstance(labels, np.ndarray):
        labels = torch.as_tensor(labels, dtype=torch.int64)
    else:
        # scatter_ requires an int64 index tensor; no-op if already int64.
        labels = labels.long()
    if labels.ndim == 1:
        # Out-of-place unsqueeze: never mutate the caller's tensor.
        labels = labels.unsqueeze(dim=1)
    arr = torch.zeros([labels.nelement(), dim_size],
                      dtype=dtype)
    arr.scatter_(dim=1, index=labels, value=1)
    return arr


# Name -> activation callable, used by Model.  torch.tanh / torch.relu are
# the non-deprecated equivalents of F.tanh / F.relu (the functional aliases
# emit DeprecationWarnings in modern PyTorch).
active_fn_map = {
    "tanh": torch.tanh,
    "relu": torch.relu,
    "swish": swish,
}


class Model(nn.Module):
    """Small MLP policy: observation -> softmax distribution over actions.

    Args:
        active_fn: key into ``active_fn_map`` ("tanh" | "relu" | "swish").
        dims: layer widths, e.g. ``[4, 20, 2]``; ``dims[0]`` is the input
            size, ``dims[-1]`` the number of actions.
    """

    def __init__(self, active_fn, dims):
        super().__init__()
        self.active_fn = active_fn_map[active_fn]
        # nn.ModuleList registers the layers properly (parameters(), .to(),
        # state_dict) instead of the previous plain-list + setattr hack.
        self.fc_list = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
        )
        # Keep the historical "fc0", "fc1", ... attribute aliases so any
        # external code that accessed them directly keeps working.
        for i, fc in enumerate(self.fc_list):
            setattr(self, "fc%s" % i, fc)

    def forward(self, inputs):
        """Return action probabilities, shape ``[batch, dims[-1]]``."""
        x = inputs
        for fc in self.fc_list[:-1]:
            x = self.active_fn(fc(x))
        # Softmax over the last (action) dimension.
        x = F.softmax(self.fc_list[-1](x), dim=-1)
        return x


def create_optimizer(net, opt_type, **kwargs):
    """Build an optimizer over ``net.parameters()``.

    Args:
        net: the ``nn.Module`` whose parameters will be optimized.
        opt_type: optimizer name, case-insensitive ("sgd" or "adam").
        **kwargs: forwarded to the optimizer constructor (lr, momentum, ...).

    Raises:
        ValueError: if ``opt_type`` is not supported.
    """
    opt_map = {"SGD": optim.SGD, "ADAM": optim.Adam}
    opt_cls = opt_map.get(opt_type.upper())
    if opt_cls is None:
        raise ValueError("Not support optimizer: %s" % opt_type)
    return opt_cls(net.parameters(), **kwargs)


class PG(object):
    """REINFORCE (vanilla policy-gradient) agent.

    Optim target T = Sum_policy( sum(r_t) * p(s_0,a_0,s_1,a_1,...,s_n,a_n) ) = Expect_policy( sum(r_t) )
                   = mean( sum(r_t) )

                dT = sum(r_t) * dp(s_0,a_0,s_1,a_1,...,s_n,a_n)
                   = sum(r_t) * p(s_0,a_0,s_1,a_1,...,s_n,a_n) * dp(s_0,a_0,s_1,a_1,...,s_n,a_n) / p(s_0,a_0,s_1,a_1,...,s_n,a_n)
                   = sum(r_t) * p(s_0,a_0,s_1,a_1,...,s_n,a_n) * dlog(p(s_0,a_0,s_1,a_1,...,s_n,a_n))
                   = sum(r_t) * p(s_0,a_0,s_1,a_1,...,s_n,a_n) * sum(dlog(p(a_i|s_i)))
                   = Expect_policy( sum(r_t) * sum(dlog(p(a_i|s_i))) )
                   = mean( sum(r_t) * sum(dlog(p(a_i|s_i))) )
                   = mean( G_0 * sum(dlog(p(a_i|s_i))) )
                   = mean( sum(G_0 * dlog(p(a_i|s_i))) )
    """

    def __init__(self, env, model_params, optimizer_params, gamma=0.9,
                 loss_type="cross_advantageDecaySum"):
        """Create the policy network, optimizer, and learning state.

        Args:
            env: the environment (stored, not used by this class directly).
            model_params: kwargs for Model (active_fn, dims).
            optimizer_params: kwargs for create_optimizer (opt_type, lr, ...).
            gamma: discount factor for the decayed-return variants.
            loss_type: "<cross|log>_<sum|advantageSum|decaySum|advantageDecaySum>".
        """
        self.env = env
        self.net = Model(**model_params)
        self.optimizer = create_optimizer(self.net, **optimizer_params)
        # Created but only stepped if the commented-out call in update()
        # is re-enabled.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 500, gamma=0.5)
        # (observation, action, reward) tuples of the episode in progress.
        self.episode_policy_record = []
        # Exponential moving-average baselines for the "advantage*" losses.
        self.mean_total_reward = 0.0
        self.mean_decay_reward = 0.0
        self.gamma = gamma
        self.loss_type = loss_type

    def action(self, observation, pr=False):
        """Sample an action from the current policy for one observation."""
        self.net.eval()
        with torch.no_grad():
            inputs = torch.tensor([observation], dtype=torch.float32)
            probs = self.net(inputs)
            if pr:
                print(inputs.detach().numpy(), probs.detach().numpy())
            p = probs.numpy()[0]
            # Sample over however many actions the network outputs
            # (previously hard-coded to 2 for CartPole).
            act = int(np.random.choice(len(p), p=p))
        self.net.train()
        return act

    def update(self, observation, action, reward, observation_next, done):
        """Record one transition; on episode end, do one REINFORCE update.

        ``observation_next`` is unused — kept so the signature mirrors an
        env step (s, a, r, s', done).
        """
        self.episode_policy_record.append((observation, action, reward))
        if done:
            obs_batch = [x[0] for x in self.episode_policy_record]
            obs_batch = np.asarray(obs_batch, dtype=np.float32)
            act_batch = [x[1] for x in self.episode_policy_record]
            act_batch = np.asarray(act_batch, dtype=np.int32)
            total_reward = sum(x[2] for x in self.episode_policy_record)

            # Discounted returns G_t, computed right-to-left.
            cumsum_rewards = []
            r = 0
            for i in range(len(act_batch) - 1, -1, -1):
                r = self.episode_policy_record[i][2] + self.gamma * r
                cumsum_rewards.insert(0, r)
            cumsum_rewards = np.asarray(cumsum_rewards, dtype=np.float32)

            self.optimizer.zero_grad()
            probs = self.net(torch.tensor(obs_batch))

            # One-hot width follows the policy's action dimension
            # (previously hard-coded to 2 for CartPole).
            act_batch = one_hot(act_batch, probs.shape[-1])

            # Per-step log-likelihood term (1e-9 guards log(0)).
            if self.loss_type.startswith("cross"):
                loss = act_batch * torch.log(probs + 1e-9) \
                       + (1 - act_batch) * torch.log(1 - probs + 1e-9)
            elif self.loss_type.startswith("log"):
                loss = act_batch * torch.log(probs + 1e-9)
            else:
                raise ValueError("Unknown loss type:", self.loss_type)

            loss = loss.sum(dim=1)

            # Suffix matching is case-sensitive: "sum" matches only
            # "cross_sum"/"log_sum", not the "...Sum" variants.
            if self.loss_type.endswith("sum"):
                loss = - total_reward * loss
            elif self.loss_type.endswith("advantageSum"):
                loss = - (total_reward - self.mean_total_reward) * loss
            elif self.loss_type.endswith("decaySum"):
                loss = - loss * torch.tensor(cumsum_rewards)
            elif self.loss_type.endswith("advantageDecaySum"):
                loss = - loss * torch.tensor(cumsum_rewards - self.mean_decay_reward)
            else:
                raise ValueError("Unknown loss type:", self.loss_type)

            # Baselines are refreshed only AFTER being used as the advantage
            # above — keep this ordering.
            self.mean_total_reward += 0.1 * (total_reward - self.mean_total_reward)
            self.mean_decay_reward += 0.1 * (cumsum_rewards.mean() - self.mean_decay_reward)

            loss = loss.mean()

            loss.backward()
            self.optimizer.step()
            # self.scheduler.step()

            self.episode_policy_record.clear()


def evaluate(agent, times=10, pr_final=False):
    """Play `times` CartPole episodes with the agent's (sampled) policy.

    Returns the mean total reward over all episodes.  When ``pr_final``
    is true, the very last episode prints inputs/probabilities per step.
    """
    env = gym.make("CartPole-v0")
    rewards = []
    for episode in range(times):
        verbose = pr_final and episode == times - 1
        obs = env.reset()
        episode_reward = 0
        done = False
        while not done:
            obs, reward, done, _ = env.step(agent.action(obs, pr=verbose))
            episode_reward += reward
        rewards.append(episode_reward)
    return sum(rewards) / times


def train(agent_cfg, max_episode):
    """Train a PG agent on CartPole-v0 for ``max_episode`` episodes.

    Returns a dict with per-episode training rewards, periodic evaluation
    scores (every 50 episodes), and a final 100-episode evaluation.
    """
    env = gym.make("CartPole-v0")
    print(env.metadata)

    agent = PG(env, **agent_cfg)

    train_total_reward_list = []
    eval_total_reward_list = []
    steps = 0

    for episode in range(max_episode):
        obs = env.reset()
        episode_reward = 0
        done = False
        while not done:
            steps += 1
            act = agent.action(obs)
            # Interact with the environment to obtain the next state.
            next_obs, reward, done, info = env.step(act)
            agent.update(obs, act, reward, next_obs, done)
            obs = next_obs
            episode_reward += reward
        print("episode:", episode, "steps:", steps, "total_reward:", episode_reward)
        train_total_reward_list.append((episode, steps, episode_reward))
        # Periodic evaluation every 50 episodes.
        if (episode + 1) % 50 == 0:
            score = evaluate(agent, 10)
            eval_total_reward_list.append((episode, steps, score))
            print("episode:", episode, "evaluate:", score)

    final_score = evaluate(agent, 100, pr_final=True)
    print("Final evaluate:", final_score)
    return {"train_rewards": train_total_reward_list,
            "evaluate_rewards": eval_total_reward_list,
            "final_evaluate_rewards": (max_episode, steps, final_score)}
    
    
def search_space():
    """Yield ``(index_tuple, agent_cfg)`` pairs over the hyper-parameter grid.

    The index tuple orders as (gamma, momentum, opt_type, dims, active_fn,
    lr, loss_type).  Only the FIRST gamma value is swept — the generator
    returns after one full pass, matching the original behaviour.
    """
    activations = ["tanh", "relu", "swish"]
    dim_choices = [[4, 20, 2], [4, 64, 16, 2]]
    optimizers = ['sgd']
    learning_rates = [1e-2 / (4 ** i) for i in range(11)]
    momentums = [0.1 * i for i in range(1, 10)][::-1]
    gammas = [0.1 * i for i in range(1, 10)][::-1]
    loss_types = ["%s_%s" % (x, y) for x in ('cross', 'log')
                  for y in ('sum', 'advantageSum', 'decaySum', 'advantageDecaySum')]

    for g_idx, gamma in enumerate(gammas):
        for m_idx, momentum in enumerate(momentums):
            for o_idx, opt in enumerate(optimizers):
                for d_idx, dims in enumerate(dim_choices):
                    for a_idx, act in enumerate(activations):
                        for l_idx, lr in enumerate(learning_rates):
                            for t_idx, loss in enumerate(loss_types):
                                cfg = {
                                    "model_params": {"active_fn": act, "dims": dims},
                                    "optimizer_params": {"opt_type": opt, "lr": lr, "momentum": momentum},
                                    "gamma": gamma,
                                    "loss_type": loss
                                }
                                yield (g_idx, m_idx, o_idx, d_idx, a_idx, l_idx, t_idx), cfg
        # Stop after the first gamma value, exactly like the original break.
        return


def search():
    """Run the full grid search, training each config ``repeat_num`` times.

    One JSON log file is written per configuration under logs/, named by
    the zero-padded index tuple.
    """
    repeat_num = 3
    cfg_index = 0
    for idx, cfg in search_space():
        run_logs = [train(cfg, max_episode=2000) for _ in range(repeat_num)]
        record = {"cfg": cfg,
                  "trainLogs": run_logs}
        tag = "_".join("%02d" % part for part in idx)
        with open("logs/trainLogs-%s.json" % tag, "w") as f:
            json.dump(record, f)
        cfg_index += 1
    
    
if __name__ == "__main__":
    # Default single-run configuration (see search_space() for the full grid).
    agent_cfg = {
        "model_params": {"active_fn": "relu", "dims": [4, 20, 2]},
        "optimizer_params": {"opt_type": "sgd", "lr": 5e-4, "momentum": 0.9},
        "gamma": 0.9,
        "loss_type": "log_advantageSum"
    }
    cfg = agent_cfg
    train(cfg, max_episode=2000)


